import numpy as np
from Env import FiniteStateFiniteActionMDP
import matplotlib.pyplot as plt

class Qlearning_gen_AMB:
    """Q-learning agent with upper/lower Q bounds and action elimination.

    The agent keeps an optimistic (``QU``) and a pessimistic (``QL``)
    estimate of the step-indexed Q-function.  An action at ``(h, s)`` is
    eliminated (``A[h, s, a] = 0``) once its upper bound drops below the best
    lower bound ``VL[h, s]``; ``(h, s)`` is marked decided (``G[h, s] = 1``)
    when exactly one action survives.  Updates bootstrap from the first
    undecided step later in the trajectory, summing the rewards in between.
    NOTE(review): the name suggests the AMB (adaptive multi-step bootstrap)
    algorithm -- confirm against the reference paper.

    ``mdp`` must expose ``H`` (horizon), ``S`` (number of states), ``A``
    (number of actions), ``reset() -> state``,
    ``step(action) -> (next_state, reward)``,
    ``best_gen() -> (best_value, best_policy, best_Q)`` and
    ``value_gen(policy) -> value``, where ``value`` is indexed by initial
    state and ``policy`` is an ``(H, S, A)`` one-hot array.
    """

    def __init__(self, mdp, c, total_episodes):
        self.mdp = mdp
        self.c = c  # scale factor of the exploration bonus
        self.total_episodes = total_episodes

        H, S, A = self.mdp.H, self.mdp.S, self.mdp.A

        # State-value bounds; row H is the terminal row and stays zero.
        self.VU = np.zeros((H + 1, S), dtype=np.float32)
        self.VL = np.zeros((H + 1, S), dtype=np.float32)

        # Upper bound starts at the remaining horizon H - h, lower bound at 0.
        remaining = np.arange(H, 0, -1, dtype=np.float32)  # H - h, h = 0..H-1
        self.QU = np.tile(remaining[:, None, None], (1, S, A))
        self.QL = np.zeros((H, S, A), dtype=np.float32)

        # N: number of updates applied to each (h, s, a).
        # n: 0/1 flag marking the (h, s, a) visited in the current episode.
        self.N = np.zeros((H, S, A), dtype=np.int32)
        self.n = np.zeros((H, S, A), dtype=np.int32)

        # A: 1 while an action is still active, 0 once eliminated.
        # G: 1 once a state has a single remaining action (decided).
        self.A = np.full((H, S, A), 1, dtype=np.int32)
        self.G = np.zeros((H, S), dtype=np.int32)

        # Trajectory of the current episode (index H holds the final state).
        self.episode_state = np.zeros(H + 1, dtype=np.int32)
        self.episode_action = np.zeros(H + 1, dtype=np.int32)

        self.regret = []   # running average regret after each episode
        self.raw_gap = []  # per-episode value gap at the initial state

    def run_episode(self, actions_policy=None):
        """Roll out one episode and record the visited trajectory.

        Args:
            actions_policy: optional ``(H, S, A)`` one-hot policy to follow.
                When omitted, the policy returned by ``choose_action()`` is
                used.  Pass the policy explicitly when the caller also needs
                to evaluate exactly the policy that was executed.

        Returns:
            ``(rewards, state_init)``: ``rewards`` is an ``(H, S, A)`` array
            holding the reward observed at each visited
            ``(step, state, action)`` (zero elsewhere); ``state_init`` is the
            state returned by ``mdp.reset()``.
        """
        if actions_policy is None:
            actions_policy = self.choose_action()
        H = self.mdp.H
        state = self.mdp.reset()
        state_init = state
        rewards = np.zeros((H, self.mdp.S, self.mdp.A))
        for step in range(H):
            action = np.argmax(actions_policy[step, state])
            next_state, reward = self.mdp.step(action)

            self.episode_state[step] = state
            self.episode_action[step] = action
            # Flag (not count) the visit; update_QAMB consumes and clears it.
            self.n[step, state, action] = 1
            rewards[step, state, action] = reward
            state = next_state
        # Keep the final state too, so episode_state[H] is meaningful when
        # update_QAMB bootstraps from step H (VU/VL are zero there anyway).
        self.episode_state[H] = state
        return rewards, state_init

    def choose_action(self):
        """Build the current one-hot ``(H, S, A)`` policy.

        At a state with several active actions, pick the active action with
        the widest confidence interval ``QU - QL``; eliminated actions are
        never chosen.  At a state with a single active action, pick it.
        If no action is active, the row stays all-zero.
        """
        H, S = self.mdp.H, self.mdp.S
        actions = np.zeros([H, S, self.mdp.A])
        for step in range(H):
            for state in range(S):
                active = self.A[step, state, :] == 1
                if active.sum() > 1:
                    width = self.QU[step, state] - self.QL[step, state]
                    # Mask eliminated actions so they can never win argmax.
                    best_action = np.argmax(np.where(active, width, -np.inf))
                else:
                    best_action = np.flatnonzero(active)
                actions[step, state, best_action] = 1
        return actions

    def first_undecided_state(self, step):
        """Return the first step after ``step`` whose visited state is undecided.

        Looks at ``G[i, episode_state[i]]`` for ``i = step + 1 .. H - 1`` and
        returns the first ``i`` with ``G == 0``; returns ``H`` if there is
        none (including when ``step == H - 1``).
        """
        for i in range(step + 1, self.mdp.H):
            if self.G[i, self.episode_state[i]] == 0:
                return i
        return self.mdp.H

    def update_QAMB(self, rewards):
        """Update the Q bounds for every ``(h, s, a)`` flagged in ``self.n``.

        Decided states (``G[h, s] == 1``) are skipped.  The target sums
        ``rewards`` along the recorded trajectory from step ``h`` up to (not
        including) the first undecided later step ``h'``.  It then
        bootstraps from ``VU``/``VL`` at ``(h', episode_state[h'])``, adding
        the bonus to the upper bound and subtracting it from the lower one.
        ``QU`` is capped at ``H`` and ``QL`` floored at 0.  Clears ``n``,
        ``episode_state`` and ``episode_action`` afterwards.  ``VU``/``VL``
        are not refreshed here.

        Args:
            rewards: ``(H, S, A)`` array of stored rewards.
        """
        H = self.mdp.H
        for h in range(H - 1, -1, -1):
            for s, a in np.argwhere(self.n[h] != 0):
                if self.G[h, s] == 1:
                    continue
                self.N[h, s, a] += 1
                n_updates = self.N[h, s, a]
                step_size = (H + 1) / (H + n_updates)
                ucb_bonus = self.c * (H - h - 1) * np.sqrt(H / n_updates)

                hprime = self.first_undecided_state(h)
                steps = np.arange(h, hprime)
                sum_rewards = rewards[steps,
                                      self.episode_state[steps],
                                      self.episode_action[steps]].sum()
                s_next = self.episode_state[hprime]

                self.QU[h, s, a] = min(H, (1 - step_size) * self.QU[h, s, a]
                    + step_size * (sum_rewards + self.VU[hprime, s_next] + ucb_bonus))
                self.QL[h, s, a] = max(0, (1 - step_size) * self.QL[h, s, a]
                    + step_size * (sum_rewards + self.VL[hprime, s_next] - ucb_bonus))
        self.n.fill(0)
        self.episode_state.fill(0)
        self.episode_action.fill(0)

    def _refresh_value_bounds(self):
        """Set ``VU[:H]``/``VL[:H]`` to the max over actions of ``QU``/``QL``."""
        H = self.mdp.H
        self.VU[:H] = self.QU.max(axis=2)
        self.VL[:H] = self.QL.max(axis=2)

    def learn(self):
        """Run ``total_episodes`` episodes, updating bounds and eliminations.

        Each episode computes one policy, executes that same policy, and
        evaluates it with ``mdp.value_gen``.  The gap to
        ``best_value`` at the initial state is appended to ``self.raw_gap``
        and the running average regret to ``self.regret``.

        Returns:
            ``(best_value, best_Q, value, QU, raw_gap)``, where ``value`` is
            the ``value_gen`` result of the last episode's policy, or
            ``None`` when ``total_episodes == 0``.
        """
        H = self.mdp.H
        # Cumulative regret for this call.
        self.regret_cum = 0
        best_value, _, best_Q = self.mdp.best_gen()

        # Cache of observed rewards (assumes deterministic rewards: each entry
        # is filled the first time a nonzero reward is observed for it).
        rewards = np.zeros((H, self.mdp.S, self.mdp.A))
        self._refresh_value_bounds()

        value = None
        for episode in range(self.total_episodes):
            # One policy per episode: the evaluated policy is the executed one.
            actions_policy = self.choose_action()
            run_reward, state_init = self.run_episode(actions_policy)
            value = self.mdp.value_gen(actions_policy)

            gap = best_value[state_init] - value[state_init]
            self.regret_cum = self.regret_cum + gap
            self.regret.append(self.regret_cum / (episode + 1))
            self.raw_gap.append(gap)

            # Store rewards for the (h, s, a) actually visited this episode.
            visited = (self.n != 0) & (rewards == 0)
            rewards[visited] = run_reward[visited]

            self.update_QAMB(rewards)
            self._refresh_value_bounds()
            # Eliminate actions whose upper bound falls below the best lower bound.
            self.A[self.QU < self.VL[:H, :, None]] = 0
            # A state with exactly one active action is decided for good.
            self.G[self.A.sum(axis=2) == 1] = 1
        return best_value, best_Q, value, self.QU, self.raw_gap